#!/usr/bin/env python
import format_tools as ft
import re, sys, string
from textwrap import wrap
from itertools import product
from copy import deepcopy
import math

__author__ = 'Gordon Sun'

# Global variables containing the primary alphabets and pairing tables.  Nucleotide pairings are A-T and G-C as found
# in biological DNA sequences, while codon pairings map strings of 3 nucleotides to an amino acid (designated by its
# one-letter code).  The pairing tables are read from disk with ft.read_table when this module is imported.
nt_alphabet = 'atgc'
aa_alphabet = 'galmfwkqespvicyhrndt*'
codon_pairing = ft.read_table('Codon_Lib')
nt_pairing = ft.read_table('NT_Lib')
restriction_sites = ft.read_table('restriction_enzymes')
# Number of bases taken from each side of a fragment junction when building an overlap (see designate_overlap).
fragment_overlap_length = 50


# takes in a single nucleotide char and returns its pairing partner
# I: <char> nt
# O: <char> pairing partner, or None when nt is not found in nt_pairing
def return_nt_complement(nt):
    '''returns the nucleotide complement of nt.

    Scans nt_pairing in order and returns the second field of the first row whose
    first field equals nt (case-sensitive comparison); None if no row matches.'''
    partners = (row[1] for row in nt_pairing if nt == row[0])
    return next(partners, None)


# maps a codon to its amino acid, or an amino acid to every codon that encodes it, using codon_pairing
# I: <string> nt codon (length > 1) or <char> aa (length <= 1)
# O: <char> AA letter abbreviation (None if the codon is unknown) or <list> of codon <string>s
def return_aa_match(aa_or_codon):
    '''returns either the aa for a codon, or the list of codons for an aa.

    Input longer than one character is treated as a codon and the first matching
    row's amino acid is returned (None when nothing matches).  Otherwise the input
    is treated as an amino acid and all matching codons are returned, in table
    order (an empty list when nothing matches).  Matching is case-insensitive on
    the input side.'''
    if len(aa_or_codon) > 1:
        return next((row[1] for row in codon_pairing if aa_or_codon.lower() == row[0]), None)
    return [row[0] for row in codon_pairing if aa_or_codon.lower() == row[1]]


# finds every match of a search term within a search body and reports where each match starts.
# I: <string> search term (interpreted as a regular expression), <string> search body
# O: <list> of start indices (empty when there is no match)
def find_all_instances(searchterm, searchbody):
    '''returns the start index of every non-overlapping regex match of searchterm in searchbody, [] if none'''
    starts = []
    for match in re.finditer(searchterm, searchbody):
        starts.append(match.start())
    return starts


# takes in a DNA sequence and returns its complement, base by base (the result is not reversed)
# I:  <string> NT sequence
# O:  <string> complementary NT sequence
def return_complement(sequence):
    '''returns the DNA complement of a sequence by complementing each base in place'''
    return ''.join([return_nt_complement(base) for base in sequence])


# Function determines if two DNA sequences are complementary, position by position
# I: <string> seq1, <string> seq2
# O: <bool> True if every base of seq1 pairs with the base at the same position of seq2
def is_complement(seq1, seq2):
    '''returns True if the sequences are complementary to one another, False otherwise.

    Position b of seq1 is compared against return_nt_complement(seq2[b]); neither
    sequence is reversed.  A diagnostic line is printed when the lengths differ or
    at the first mismatching position.'''
    if len(seq1) != len(seq2):
        print("sequences are different lengths")
        return False
    for position, (base1, base2) in enumerate(zip(seq1, seq2)):
        if base1 != return_nt_complement(base2):
            print("Mismatch found at position %s with seq1: %s aligned with seq2: %s" % (position, base1, base2))
            return False
    return True


# function determines if a sequence contains only allowed characters
# I: <string> word and <string> alphabet
# O: <bool> True if every character of word is in alphabet
def valid_word(word, alphabet):
    '''returns <bool> on whether every letter of word appears in alphabet (True for an empty word)'''
    return all(letter in alphabet for letter in word)


# function that finds ALL possible proteins in a provided sequence.
# I: a start and an end aa (one-letter code) <string>, and the target sequence (<string> or list of strings)
# O: Outputs a <list> of nt stretches, each running from a start codon through an in-frame end codon
def find_prots(start_c, end_c, seq):
    '''finds all potential 3 nt based codon proteins in the DNA sequence and returns all possible proteins.

    Every occurrence of a codon for start_c is paired with EVERY downstream,
    in-frame occurrence of a codon for end_c (not only the first one), so longer
    results can contain internal end codons.  Results are ordered by start
    position, then end position.'''
    sequence = ''.join(seq).lower()

    def codon_positions(codons):
        # start index of every occurrence of any of the given codons, ascending
        hits = []
        for codon in codons:
            hits.extend(find_all_instances(codon, sequence))
        return sorted(hits)

    starts = codon_positions(return_aa_match(start_c))
    ends = codon_positions(return_aa_match(end_c))
    return [sequence[s:e + 3]
            for s in starts
            for e in ends
            if e > s and (e - s) % 3 == 0]


# function returns the frequency-table row of a 3nt codon
# I: <str> codon input, <list> organism codon table
# O: <list> table row for the codon ([] for invalid input, None when the codon is not in the table)
def get_codon_freq(codon_input, org_codon_table):
    '''returns the codon frequency row of the organism for the given codon.

    The codon is matched case-insensitively against the first field of each row.
    Input that is not exactly three valid nucleotides is rejected with a message.'''
    codon = codon_input.lower()
    # either condition alone disqualifies the input (wrong length OR non-nucleotide characters)
    if len(codon) != 3 or not valid_word(codon, nt_alphabet):
        print("not a fuckin codon")
        return []
    for row in org_codon_table:
        if row[0] == codon:
            return row
    return None


# function returns the frequency-table rows of every codon for an aa
# I: <str> aa input, <list> organism codon table
# O: <list> of table rows for that aa, sorted on field 3 in descending order ([] for invalid input)
def get_AA_freq(AA_input, org_codon_table):
    '''returns the codon frequency rows of the organism for an aa, covering all associated codons.

    NOTE(review): this sorts on field 3 while rev_translator/codon_possibilities
    sort on field 2; if the table stores these values as strings the ordering is
    lexicographic, not numeric -- confirm against the table format.'''
    aa = AA_input.lower()
    if not valid_word(aa, aa_alphabet):
        print("not a fuckin aa")
        return []
    matching_rows = [row for row in org_codon_table if row[1] == aa]
    return sorted(matching_rows, key=lambda row: row[3], reverse=True)


# function reads in a codon and a usage ranking (1 = most used) and returns the synonymous codon of that rank
# I: <str> nt codon of length 3, <int> ranking (1-based), <list> organism codon table
# O: <str> 3nt codon
def optimal_codon(codon, number, org_codon_table):
    '''returns the codon ranked `number` (1 = most used) among the codons for the same aa as codon.

    The codon is matched case-insensitively.  On invalid input an error is
    printed and the program exits with status 1.'''
    codon = codon.lower()
    if len(codon) != 3 or not valid_word(codon, nt_alphabet):
        print(">Error: Index error accessing codon - insufficient length")
        sys.exit(1)
    aa = return_aa_match(codon)
    if aa is None:
        # return_aa_match found no row for this codon in codon_pairing
        print(">Error: codon not found in codon pairing table")
        sys.exit(1)
    codon_freq = get_AA_freq(aa, org_codon_table)
    # rankings are 1-based; 0 or negatives would otherwise silently index from the end of the list
    if not 1 <= number <= len(codon_freq):
        print(">Error: Index error accessing codon frequency table")
        sys.exit(1)
    return codon_freq[number - 1][0]


# function returns the most optimized codon sequence for a protein
# I:  <str> nt sequence and <list> organism codon table
# O: <str> codon sequence, or the error text when the length is not a multiple of 3
def optimized_prot(sequence, org_codon_table):
    '''returns sequence with every codon replaced by its top-ranked synonymous codon (see optimal_codon).

    If len(sequence) is not a multiple of 3 the error text is printed and returned
    instead of a sequence.'''
    if len(sequence) % 3 != 0:
        text = ">Error: Sequence % 3 != 0, codon length error"
        print(text)
        return text
    # walk the sequence one codon (3 nt) at a time
    return ''.join([optimal_codon(sequence[start:start + 3], 1, org_codon_table)
                    for start in range(0, len(sequence), 3)])


# translates a nucleotide sequence into a protein sequence
# I: <str> nt sequence
# O: <str> protein sequence, or the error text when the length is not a multiple of 3
def translator(sequence):
    '''returns the one-letter protein translation of sequence, read codon by codon from index 0.

    If len(sequence) is not a multiple of 3 the error text is printed and returned
    instead of a protein.'''
    if len(sequence) % 3 != 0:
        text = ">Error: Sequence % 3 != 0, codon length error"
        print(text)
        return text
    return ''.join([return_aa_match(sequence[start:start + 3])
                    for start in range(0, len(sequence), 3)])


# function returns all the positions of restriction sites in a sequence
# I: accepts a <str> sequence
# O: returns a 2d <list>: a copy of restriction_sites with each row's list of site start indices appended
def find_restriction_sites(sequence):
    '''returns a copy of the restriction table with each enzyme's match positions in sequence appended.

    The sequence is lowercased before searching and each row's second field is
    used as the search pattern.  An empty restriction table yields [].'''
    sequence = sequence.lower()
    # deep copy so the module-level restriction_sites table is never mutated between calls
    restriction_enzymes = deepcopy(restriction_sites)
    for enzyme in restriction_enzymes:
        enzyme.append(find_all_instances(enzyme[1], sequence))
    # NOTE(review): the original author saw indices appended more than once.  The deepcopy above already prevents
    # accumulation across calls, so this branch only fires when the table rows have three or more fields to begin
    # with -- and then it discards the indices just computed.  Confirm whether it is still needed.
    # The emptiness check keeps an empty restriction table from raising IndexError.
    if restriction_enzymes and len(restriction_enzymes[0]) > 3:
        for enzyme in restriction_enzymes:
            enzyme.pop()
    return restriction_enzymes


# function evaluates if there are any restriction sites left in the sequence, using find_restriction_sites
# I: <str> sequence data
# O: <bool> True if there are no restriction sites, False otherwise
def no_restrictions(sequence):
    '''returns True when no enzyme row reports a hit (a non-empty index list in field 2) for sequence'''
    return not any(row[2] for row in find_restriction_sites(sequence))


# Function determines if the sequence is a protein or nt sequence (case-insensitive)
# I: sequence in <str> format
# O: <int> 1 if DNA, 2 if prot, 0 if unknown
def DNAorProt(sequence):
    '''classifies sequence as DNA (1), protein (2) or unknown (0).

    DNA is tested first, so a sequence made only of a/t/g/c is reported as DNA
    even though those letters also appear in aa_alphabet.  Upper-case input is
    accepted by lowercasing it before checking against the lower-case alphabets.'''
    normalized = sequence.lower()
    if valid_word(normalized, nt_alphabet):
        return 1
    if valid_word(normalized, aa_alphabet):
        return 2
    return 0


# function takes in an index and returns the start of the codon containing it
# I: <int> index
# O: <int> largest multiple of 3 that is <= index
def def_fwd_read_frame(index):
    '''returns index rounded down to the nearest multiple of 3'''
    return index - index % 3


# function takes in an index and returns the (exclusive) end of the codon containing it
# I: <int> index
# O: <int> smallest multiple of 3 that is > index
def def_bck_read_frame(index):
    '''returns the exclusive end index of the codon containing index.

    Floor division keeps the result an int (usable as a slice bound) even under
    true division; the original special case for index % 3 == 2 is subsumed by
    this formula.'''
    return 3 * (index // 3 + 1)


# takes in a PROTEIN SEQ and returns, per residue, its codons ranked by usage.  Does not remove restriction sites
# I: <str> protein sequence and <list> codon usage frequency table
# O: <list> with one entry per residue: a <list> of that residue's codons sorted on frequency field 2, highest first
def rev_translator(sequence, codon_freq_table):
    '''returns the ranked codon choices for every residue of a protein sequence'''
    seq_possibilities = []
    for residue in sequence:
        # return_aa_match in aa mode gives every codon for the residue
        rows = [get_codon_freq(codon, codon_freq_table) for codon in return_aa_match(residue)]
        rows.sort(key=lambda row: row[2], reverse=True)
        seq_possibilities.append([row[0] for row in rows])
    return seq_possibilities


# function takes in a 3nt codon and returns the codons that also yield the same amino acid, sorted by usage
# I: <str> 3nt codon, <list> codon frequency table
# O: <list> of frequency-table rows, sorted on field 2, highest first
def codon_possibilities(codon, codon_freq_table):
    '''returns the frequency rows of every codon synonymous with codon (including itself)'''
    synonyms = return_aa_match(return_aa_match(codon))
    rows = [get_codon_freq(synonym, codon_freq_table) for synonym in synonyms]
    return sorted(rows, key=lambda row: row[2], reverse=True)


# given a sequence, returns, for each codon position, the codons that yield the same amino acid in
# order of greatest frequency of usage
# I: <str> sequence, <list> codon frequency table
# O: returns the options <list> of <list>s of codon <str>s
def codon_list_options(sequence, codon_freq_table):
    '''returns the ranked synonymous codons for each consecutive 3-nt chunk of sequence'''
    # textwrap.wrap splits an unbroken lower-case nt string into 3-character chunks
    codons = wrap(sequence.lower(), 3)
    return [[row[0] for row in codon_possibilities(codon, codon_freq_table)] for codon in codons]


# companion function of all_variants: multiplies the usage frequencies of a codon cassette's codons so that
# cassettes can be ranked by likelihood.
# I: sequence tuple ('aaa','ggg','ttt') for ONE sequence, <list> codon frequency table
# O: <float> product of the codons' frequency field 2 (the int 1 for an empty cassette)
def seq_probability(seq, codon_freq_table):
    '''returns the product of float(field 2) over each codon's frequency-table row'''
    running_product = 1
    for codon in seq:
        frequency_row = get_codon_freq(codon, codon_freq_table)
        running_product *= float(frequency_row[2])
    return running_product


# function given an array of choices like array = [['a','b','c'],['d','e'],['f'],['g','h']] generates every
# combination taking one entry from each position (the Cartesian product), i.e.
# ['adfg', 'adfh', 'aefg', 'aefh', 'bdfg', 'bdfh', 'befg', 'befh', 'cdfg', 'cdfh', 'cefg', 'cefh']
# and orders them by the product of their codon frequencies (seq_probability), most probable first
# I: array of choices, <list> codon frequency table
# O: list of joined strings, most probable first (ties keep itertools.product order)
def all_variants(choices, codon_freq_table):
    '''returns every codon combination from choices as a string, ranked by probability.

    The number of combinations is the product of the option counts, so keep the
    choices short.'''
    scored = []
    for combination in product(*choices):
        scored.append((''.join(combination), seq_probability(combination, codon_freq_table)))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [joined for joined, _ in scored]


# function takes in a protein or DNA sequence, codon-optimizes it, and removes restriction sites from the result.
# I: <str> sequence to be optimized and the <list> codon frequencies it is to be optimized to
# O: returns <str> of optimized DNA sequence (or an error string when the input cannot be processed).
def prot_optimization(sequence, codon_freq):
    '''returns a codon-optimized DNA sequence with restriction sites removed where possible.

    DNA input is first translated to protein; the protein is back-translated with
    the most frequent codon for every residue.  While restriction sites remain,
    the codons covering the first hit enzyme's sites are replaced by the highest
    ranked synonymous combination (see all_variants) that is itself free of
    restriction sites.  If the loop comes back to a sequence it has already tried
    (including a pass that changes nothing), a warning is printed and that
    sequence is returned instead of looping forever.'''
    seq_type = DNAorProt(sequence)
    if seq_type not in (1, 2):
        print(">Error: Sequence not recognized; unable to optimize protein->DNA seq")
        return "Sequence not Recognized, unable to optimize"
    if seq_type == 1:
        protein = translator(sequence)
        if protein.startswith(">Error"):
            # translator returns its error text (length not a multiple of 3); back-translating it would crash
            return protein
    else:
        protein = sequence
    dnatranslation = ''.join([row[0] for row in rev_translator(protein, codon_freq)])
    attempted = set()
    while not no_restrictions(dnatranslation):
        if dnatranslation in attempted:
            # no further progress is possible (or the edits cycle); stop rather than loop forever
            print(">Warning: unable to remove all restriction sites; returning best attempt")
            break
        attempted.add(dnatranslation)
        rx_sites = [row for row in find_restriction_sites(dnatranslation) if row[2]]  # only enzymes with hits
        site = rx_sites[0]
        for element in site[-1]:
            # widen the hit to whole codons so synonymous codons can be swapped in without a frame shift
            fwd_index = def_fwd_read_frame(element)
            bck_index = def_bck_read_frame(element + len(site[1]) - 1)
            window = dnatranslation[fwd_index:bck_index]
            options = codon_list_options(window, codon_freq)
            for candidate in all_variants(options, codon_freq):
                if no_restrictions(candidate):
                    # splice at this exact position: a global replace would also rewrite identical
                    # (possibly out-of-frame) substrings elsewhere in the sequence
                    dnatranslation = dnatranslation[:fwd_index] + candidate + dnatranslation[bck_index:]
                    break
    return dnatranslation


# function determines the melting temperature of a sequence given its sequence and its salt concentration
# I: takes a <str> sequence and <float> salt concentration (presumably molar, given the 0.050 reference -- confirm)
# O: outputs a <float> melting temperature
def melting_temp(fragment, salt_conc):
    '''returns the salt-adjusted melting temperature of fragment.

    Only a/c/g/t (any case) are counted.  Fragments of at most 13 nucleotides use
    2*(A+T) + 4*(G+C) - 16.6*(log10(0.050) - log10(salt_conc)); longer ones use
    100.5 + 41*(G+C)/N - 820/N + 16.6*log10(salt_conc).'''
    fragment = fragment.lower()
    at_count = fragment.count('a') + fragment.count('t')
    gc_count = fragment.count('g') + fragment.count('c')
    # choose the formula by nucleotide count rather than raw string length, so annotation characters such as '*'
    # cannot push a short oligo onto the long formula, nor divide by zero when no bases are present
    nt_count = at_count + gc_count
    salt_correction = math.log10(salt_conc)
    if nt_count <= 13:
        return (at_count * 2.0) + (gc_count * 4.0) - 16.6 * (math.log10(0.050) - salt_correction)
    return 100.5 + 41 * (float(gc_count) / nt_count) - (820.0 / nt_count) + 16.6 * salt_correction


# function takes in two adjacent fragments (fragment1 immediately precedes fragment2) and returns the sequence around
# their junction: up to overlap_length bases from the end of fragment1 followed by up to overlap_length bases from the
# start of fragment2.
# I: takes in two fragment sequences <str> and an optional <int> overlap length (default fragment_overlap_length)
# O: Returns a <str> containing the rear of fragment 1 and the front of fragment 2
def designate_overlap(fragment1, fragment2, overlap_length=None):
    '''returns the last overlap_length bases of fragment1 joined to the first overlap_length bases of fragment2.

    A fragment shorter than overlap_length contributes all of its bases.'''
    if overlap_length is None:
        overlap_length = fragment_overlap_length
    # clamp at 0: a negative start would wrap around and keep only part of a short fragment1,
    # and fragment1[-0:] would return the whole fragment when overlap_length is 0
    frag1_rear = fragment1[max(0, len(fragment1) - overlap_length):]
    frag2_front = fragment2[:overlap_length]
    return frag1_rear + frag2_front


# function takes in a junction between two fragments, a target melting temperature, and the salt concentration of the
# gibson reaction mix and outputs the overlap whose Tm is closest to the target (found by widening a tolerance in
# 0.01 steps).  Candidate overlaps are windows of 15-25 nt that touch the junction midpoint.
# I: <str> junction, <float> melting temperature, <float> salt concentration
# O: Outputs <list> containing <str> overlap, <str> overlap with * marking the junction midpoint, <float> Tm
def pick_overlap(junction, Tm, salt_conc):
    '''returns [overlap, marked overlap, Tm] for the overlap around the junction midpoint closest to Tm.

    Candidates are ordered by ascending Tm; at each tolerance step the first
    candidate whose Tm is within tolerance of the target is chosen.'''
    middle = len(junction) // 2  # floor division keeps this a valid slice index
    possible_overlaps = []
    for half_width in range(15, 26):
        # clamp so a short junction is not sliced from a negative (wrapped-around) start
        search_range = junction[max(0, middle - half_width):middle + half_width]
        search_box_size = len(search_range) // 2
        for y in range(search_box_size + 1):
            possible_overlaps.append(search_range[y:y + search_box_size])
    ranked_overlaps = [[overlap, melting_temp(overlap, salt_conc)] for overlap in possible_overlaps]
    sorted_ranked_overlaps = sorted(ranked_overlaps, key=lambda x: x[1])
    selected_overlap = ''
    tolerance = 0
    while not selected_overlap:
        for ranked in sorted_ranked_overlaps:
            if abs(ranked[1] - Tm) <= tolerance:
                selected_overlap = ranked
                break
        tolerance += 0.01
    overlap_seq = selected_overlap[0]
    # NOTE(review): if overlap_seq occurs more than once in the junction, the result depends on ft.array2int
    location = ft.array2int(find_all_instances(overlap_seq, junction))
    if location < middle:
        junct_location = junction[location:middle] + "*" + junction[middle:location + len(overlap_seq)]
    else:
        junct_location = '*' + overlap_seq
    selected_overlap.insert(1, junct_location)
    return selected_overlap


# Function takes in a sequence with the marked junction split location (*) and then determines the overlapping region
# and which direction it hangs
# I: <str> sequence
# O: <tuple> of the <str> overlap region and a <bool> that is True when the region lies right of the marker
#   (False when the marker is the last character and the region lies to its left)
def determine_overlapping_end(sequence):
    '''returns (overlap, on_right) for a sequence containing a '*' junction marker.

    Expects one marker; how several markers are handled depends on ft.array2int.'''
    last_index = len(sequence) - 1
    marker = ft.array2int(find_all_instances(r"\*", sequence))
    if marker == last_index:
        return (sequence[:last_index], False)
    return (sequence[marker + 1:], True)


# helper for fwd_primer_extension: grows a primer base by base from the template until its Tm is close to the target
def _extend_to_tm(template, seed, start, target_tm, salt_conc, tolerance):
    '''returns the bases appended after seed (taken from template[start:]) until melting_temp(seed + bases) is within
    tolerance of target_tm.  If the template runs out first, a warning is printed and the extension whose Tm came
    closest to the target is returned instead of raising IndexError.'''
    extension = ""
    diff = abs(melting_temp(seed, salt_conc) - target_tm)
    best_extension, best_diff = extension, diff
    position = start
    while diff >= tolerance:
        if position >= len(template):
            print(">Warning: template exhausted before reaching target Tm; using closest extension")
            return best_extension
        extension += template[position]
        position += 1
        diff = abs(melting_temp(seed + extension, salt_conc) - target_tm)
        if diff < best_diff:
            best_extension, best_diff = extension, diff
    return extension


# function reads in a template, a marked junction (primer region), melting temp, and salt_concentration to calculate
# primer extension (if necessary) to bring the binding region (sometimes exclusive of homology) to the appropriate melting
# temperature. function designs only the forward primer on the 5'-3' strand
# I: <str> template sequence, <str> marked junction  <float> melting temp <float> salt concentration
# O: <tuple> (<str> complete primer without the '*' marker, <float> Tm of its binding region)
def fwd_primer_extension(junction_template, marked_junction, Tm, salt_conc):
    '''returns (primer, Tm) for the forward primer spanning the marked junction.

    A right-hanging overlap is extended until its binding region's Tm is within 2
    degrees of Tm; a left-hanging one is extended from the template midpoint until
    within 1 degree.  When the marker is the first character, the overlap alone is
    the primer.'''
    middle = len(junction_template) // 2  # floor division keeps this a valid slice index
    overlap_seq, on_right = determine_overlapping_end(marked_junction)
    if on_right and marked_junction[0] == "*":
        finalprimer = ft.remove_str_annotation(overlap_seq, "*")
        return (finalprimer, melting_temp(finalprimer, salt_conc))
    if on_right:
        # the binding region starts at the midpoint and already covers the part of the overlap right of the marker
        start = middle + len(overlap_seq)
        extension = _extend_to_tm(junction_template, junction_template[middle:start], start, Tm, salt_conc, 2)
        finalprimer = ft.remove_str_annotation(marked_junction + extension, "*")
        return (finalprimer, melting_temp(str(overlap_seq + extension), salt_conc))
    # overlap hangs left of the marker: the binding region is built from the template midpoint onwards
    extension = _extend_to_tm(junction_template, "", middle, Tm, salt_conc, 1)
    finalprimer = ft.remove_str_annotation(marked_junction + extension, "*")
    return (finalprimer, melting_temp(extension, salt_conc))


# function reads in a template, a marked junction (primer region), melting temp, and salt_concentration and designs the
# reverse primer: both inputs are complemented (keeping the '*' marker), passed through ft.flipflop, and handed to
# fwd_primer_extension.
# I: <str> template sequence, <str> marked junction  <float> melting temp <float> salt concentration
# O: <tuple> (<str> complete primer, <float> Tm) as returned by fwd_primer_extension
def rev_primer_extension(junction_template, marked_junction, Tm, salt_conc):
    '''returns the reverse primer by running fwd_primer_extension on the complemented, flipped inputs'''
    complement = ''.join([return_nt_complement(base) for base in junction_template])
    complement_marked = ''.join(["*" if symbol == "*" else return_nt_complement(symbol)
                                 for symbol in marked_junction])
    return fwd_primer_extension(ft.flipflop(complement), ft.flipflop(complement_marked), Tm, salt_conc)

